Commit 67d1f53d authored by Ngan Thi Dong's avatar Ngan Thi Dong
Browse files

Add final results for the generalized and de novo settings; add code to extract the top predicted proteins.

parent c7c056a1
This diff is collapsed.
train_rate,test_rate,auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1
1,1,0.4982635555555553+-0.055604801730661485,0.499900404741971+-0.05142140108015179,0.49819500000000005+-0.03909777583188076,0.49819500000000005+-0.03909777583188076,0.49819500000000005+-0.03909777583188076,4.73+-2.1533926720410292,0.49820000000000014+-0.03909936060858285,0.49820000000000014+-0.03909936060858285,0.49820000000000014+-0.03909936060858285
1,2,0.5098964444444446+-0.05195553868168151,0.3471158579081102+-0.04808121379373515,0.325469+-0.08583921445936002,0.6627310000000002+-0.042911564163987316,0.5503079999999999+-0.05722113714354162,3.52+-1.9363883907935413,0.32546666666666657+-0.08583384465867115,0.32546666666666657+-0.08583384465867115,0.32546666666666657+-0.08583384465867115
1,5,0.5076174222222222+-0.052393891385264235,0.17804197841344804+-0.03211548857625681,0.15392899999999993+-0.07869020306874289,0.8307779999999996+-0.015723889976720133,0.7179730000000001+-0.026207471663630577,1.99+-1.3747363383572864,0.15391743929359827+-0.07867257183945442,0.15393333333333337+-0.0786897988588384,0.15392535991140646+-0.07868111747249373
1,10,0.5104775999999999+-0.050654402883895,0.10081962302676561+-0.021084589023987757,0.09013699999999998+-0.04596053667006076,0.9090169999999996+-0.004595586034446535,0.8345709999999994+-0.00835014125629022,1.38+-0.9568698971124547,0.09013333333333336+-0.04596114300870537,0.09013333333333336+-0.04596114300870537,0.09013333333333336+-0.04596114300870537
2,1,0.5968844444444443+-0.030718475077892635,0.6118394238709726+-0.03277214467455567,0.565807+-0.025978957850537422,0.5657399999999999+-0.02603672406429041,0.565773+-0.02600615256049999,7.01+-0.9949371839468064,0.5657637969094923+-0.025997598491192056,0.5658000000000001+-0.025968442386866426,0.565781838316722+-0.02598244872808013
2,2,0.6048637777777778+-0.024589594387629657,0.4595771055880749+-0.036785330672138856,0.48640399999999984+-0.041541232336077843,0.7431960000000002+-0.020763207459349814,0.6575939999999997+-0.027687101040015014,5.52+-1.3962807740565648,0.4863999999999997+-0.04153627597922353,0.4863999999999997+-0.04153627597922353,0.4863999999999997+-0.04153627597922353
2,5,0.6065846222222223+-0.023507736951076087,0.26932252168289134+-0.0314886500494005,0.3154670000000001+-0.045592095926816095,0.8630929999999997+-0.009111566879521875,0.7718180000000003+-0.015200956417278493,3.59+-1.1755424279880329,0.3154666666666665+-0.04558635273363679,0.3154666666666665+-0.04558635273363679,0.3154666666666665+-0.04558635273363679
2,10,0.607040711111111+-0.022445431315773617,0.16145644419958202+-0.02234351010189305,0.2050730000000001+-0.037716293176822145,0.9205130000000001+-0.0037739807895642513,0.8554729999999998+-0.006856177579380515,2.52+-1.0244998779892551,0.2050666666666667+-0.03771377585032944,0.2050666666666667+-0.03771377585032944,0.2050666666666667+-0.03771377585032942
5,1,0.4551693333333335+-0.06261426420375615,0.4660972615029609+-0.06701319722193214,0.4590700000000001+-0.04890987528096958,0.4590700000000001+-0.04890987528096958,0.4590700000000001+-0.04890987528096958,3.24+-3.3678479775666847,0.4590666666666665+-0.04889917063600248,0.4590666666666665+-0.04889917063600248,0.4590666666666665+-0.04889917063600248
5,2,0.46117711111111115+-0.05998676617058915,0.3165430212238304+-0.062487608170576296,0.251538+-0.09474721819663097,0.6257619999999998+-0.047378477772085516,0.5010220000000001+-0.06316648728558524,2.68+-2.9081953166869656,0.2515333333333334+-0.09475045587694494,0.2515333333333334+-0.09475045587694494,0.2515333333333334+-0.09475045587694494
5,5,0.45806231111111095+-0.06140253139264048,0.16193844579116404+-0.046114019091152966,0.08846300000000001+-0.10661980365298003,0.8176969999999993+-0.021324471646444108,0.6961630000000001+-0.035534863599006546,1.88+-2.2417850030723283,0.08846666666666664+-0.10662021905394455,0.08846666666666664+-0.10662021905394455,0.08846666666666665+-0.10662021905394455
5,10,0.4625836+-0.05911649015691051,0.09061465605311728+-0.02958976337774055,0.05913200000000002+-0.06880302301498095,0.9059120000000012+-0.006885757474671895,0.8289360000000008+-0.012500156159024564,1.07+-1.3729894391436523,0.05913333333333336+-0.06879699605845208,0.05913333333333336+-0.06879699605845208,0.05913333333333336+-0.06879699605845208
10,1,0.6236457777777776+-0.05917610343690526,0.6317776506885981+-0.03438202308039972,0.6034740000000002+-0.021945462492278432,0.6032730000000002+-0.022208794902020244,0.6033730000000003+-0.02206997442227788,7.92+-1.2056533498481234,0.603351876379691+-0.02210125745606409,0.6034666666666668+-0.021959154001109513,0.6034090808416391+-0.02202813768797601
10,2,0.6259648888888889+-0.05927143991360509,0.4693506057539267+-0.03228782073309199,0.4649270000000002+-0.030373981810095284,0.7324729999999997+-0.015180855410680916,0.6432929999999998+-0.02024402753900518,6.39+-1.2953377937819917,0.46493333333333353+-0.03036986810786128,0.46493333333333353+-0.03036986810786128,0.46493333333333353+-0.030369868107861277
10,5,0.637388177777778+-0.06102004274240101,0.28060519087655267+-0.026200508126947117,0.3113320000000001+-0.027051225036955347,0.8622679999999998+-0.005415457136752176,0.7704360000000005+-0.009018741819123105,4.26+-1.5467385040788244,0.311333333333333+-0.02705549851693736,0.311333333333333+-0.02705549851693736,0.311333333333333+-0.02705549851693736
10,10,0.632619911111111+-0.05848415586888302,0.16900065575011283+-0.019681785053793102,0.19807000000000005+-0.029277067817662344,0.9198100000000008+-0.0029277124175710934,0.8542009999999991+-0.005319388968669233,3.07+-1.3874797295816608,0.19806666666666664+-0.029277143151460516,0.19806666666666664+-0.029277143151460516,0.19806666666666664+-0.029277143151460516
This diff is collapsed.
train_rate,test_rate,auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1
1,1,0.6501571111111112+-0.027266077736101152,0.6518105876314051+-0.02817758777389157,0.6171310000000001+-0.02238215224235597,0.6171310000000001+-0.02238215224235597,0.6171310000000001+-0.02238215224235597,7.82+-1.3067516979135707,0.6171333333333332+-0.02237568124350877,0.6171333333333332+-0.02237568124350877,0.6171333333333332+-0.02237568124350877
1,2,0.6485598888888888+-0.02640937001202601,0.4862297004498328+-0.0304436526164803,0.496+-0.022330862052325703,0.7479999999999998+-0.011165455655726738,0.6640010000000001+-0.0148955932745225,6.51+-1.6461773901982741,0.49599999999999994+-0.0223308456325923,0.49599999999999994+-0.0223308456325923,0.49599999999999994+-0.0223308456325923
1,5,0.6587823555555555+-0.023445559711903752,0.29481406503459134+-0.024784592188969397,0.326595+-0.026384087533966386,0.8652719999999996+-0.0052713580792809085,0.77549+-0.0087793792491269,4.34+-1.7390802166662704,0.32651052631578936+-0.026345067280641832,0.3265999999999999+-0.02639014967748383,0.32655496688741714+-0.026365597539465106
1,10,0.6541743333333333+-0.02159103345755899,0.17455792838619114+-0.017374610391412984,0.21159299999999995+-0.023789885476815563,0.9210480000000006+-0.0024308632211623928,0.8565539999999995+-0.004350595821264035,3.24+-1.4705101155721427,0.21137853021951686+-0.02379769466393611,0.2116+-0.023788325801628922,0.2114877648593159+-0.023786703765479568
2,1,0.6723506666666667+-0.03110414010472434,0.6753997237899888+-0.03365240960083135,0.630532+-0.017997343581762294,0.630532+-0.017997343581762294,0.630532+-0.017997343581762294,8.15+-1.3883443376914817,0.630533333333333+-0.01799209703051748,0.630533333333333+-0.01799209703051748,0.630533333333333+-0.01799209703051748
2,2,0.6711551111111114+-0.030263647096887343,0.512636170605297+-0.031461059709777466,0.5011329999999999+-0.023887172519994917,0.7505669999999999+-0.011947351631219372,0.6674209999999999+-0.015932565989193327,6.8+-1.7029386365926393,0.5011333333333332+-0.023889653734526076,0.5011333333333332+-0.023889653734526076,0.5011333333333332+-0.023889653734526076
2,5,0.6792691555555556+-0.026732458670978156,0.3182105986789411+-0.026383613102603707,0.34267199999999987+-0.03252484613337932,0.8685280000000002+-0.006503415717913154,0.7808920000000004+-0.010838603969146574,4.68+-1.2796874618437117,0.3426666666666669+-0.03252349578040124,0.3426666666666669+-0.03252349578040124,0.3426666666666669+-0.03252349578040125
2,10,0.6766002666666666+-0.025302399236806876,0.1952634073937991+-0.022218064249176158,0.23473199999999994+-0.028459405756269762,0.923439+-0.0028572327521572275,0.8608329999999996+-0.005183127530748215,3.34+-0.9614572273377525,0.2346559139784945+-0.02845729955605227,0.2347333333333332+-0.028461201661208874,0.2346939890710381+-0.028456612990605017
5,1,0.6677391111111111+-0.02300276046129044,0.6837618973681809+-0.02990886422711015,0.6317339999999999+-0.01867364570725279,0.6317339999999999+-0.01867364570725279,0.6317339999999999+-0.01867364570725279,8.77+-1.173499041328965,0.6317333333333331+-0.01866952359089838,0.6317333333333331+-0.01866952359089838,0.6317333333333331+-0.01866952359089838
5,2,0.6681717777777776+-0.021477929149320008,0.5313068070016248+-0.02650520566884531,0.4959360000000001+-0.016269133474158974,0.7479640000000003+-0.008139748399060018,0.6639599999999998+-0.010854556646864943,7.9+-1.3527749258468684,0.49593333333333317+-0.016272539923305015,0.49593333333333317+-0.016272539923305015,0.49593333333333317+-0.016272539923305015
5,5,0.6757592888888891+-0.01882447555941237,0.3387339468947492+-0.023796022717995295,0.36166200000000004+-0.024915110997143882,0.8723379999999996+-0.004983448203804262,0.7872240000000004+-0.008305758484328804,6.25+-1.3444329659748753,0.3616666666666665+-0.02491541245449847,0.3616666666666665+-0.02491541245449847,0.3616666666666665+-0.024915412454498475
5,10,0.6738360888888885+-0.01605932103824388,0.21765589302570912+-0.018068790503237942,0.25420299999999996+-0.02157632246236603,0.9254229999999997+-0.0021591597902887935,0.8644009999999995+-0.003928752855550981,4.92+-1.390539463661496,0.25420000000000004+-0.021574779514774007,0.25420000000000004+-0.021574779514774007,0.25420000000000004+-0.021574779514774007
10,1,0.6594279999999999+-0.018948523801911452,0.6821420039281096+-0.028335886578054192,0.6186+-0.01578679828210901,0.6186+-0.01578679828210901,0.6186+-0.01578679828210901,8.96+-1.1306635220082049,0.6186000000000003+-0.015784521250614813,0.6186000000000003+-0.015784521250614813,0.6186000000000003+-0.015784521250614813
10,2,0.6617025555555552+-0.016699942696360617,0.5367600359032807+-0.024389511195308994,0.4924619999999997+-0.021737970374439284,0.7462380000000003+-0.010870517742959634,0.6616479999999999+-0.014495119730447198,8.69+-1.238507165905794,0.4924666666666663+-0.021738956327805217,0.4924666666666663+-0.021738956327805217,0.4924666666666663+-0.02173895632780522
10,5,0.6675254222222219+-0.012687019026733233,0.34412226680339364+-0.022344548647385283,0.3454679999999999+-0.024320534040189162,0.8690920000000003+-0.004863222799749143,0.7818230000000003+-0.008109831749179511,7.03+-1.4795607456268909,0.34546666666666664+-0.02431972222063585,0.34546666666666664+-0.02431972222063585,0.34546666666666664+-0.02431972222063585
10,10,0.6653817777777781+-0.008872209150480885,0.23081562492206426+-0.02382433345386149,0.24559799999999987+-0.01602071146984428,0.9245580000000007+-0.001602946037769194,0.8628289999999995+-0.0029151087458275013,6.13+-1.677229859023503,0.24559999999999985+-0.016019987515600628,0.24559999999999985+-0.016019987515600628,0.24559999999999985+-0.016019987515600624
This diff is collapsed.
train_rate,test_rate,auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1
1,1,0.6185150970301941+-0.02250645114909087,0.6259895888548804+-0.05141877336973243,0.5893069999999999+-0.010203100068116549,0.5893069999999999+-0.010203100068116549,0.5893069999999999+-0.010203100068116549,5.21+-2.6806529055437216,0.5892913385826773+-0.01020238029739546,0.5892913385826773+-0.01020238029739546,0.5892913385826773+-0.01020238029739546
1,2,0.6160012331135772+-0.021506175779113783,0.4675118410542188+-0.06463749138261841,0.5273540000000005+-0.021910524503078413,0.7636770000000002+-0.010955289635605265,0.6848830000000002+-0.014603602671943654,3.98+-2.8248185782453357,0.5273490813648293+-0.02190517908588862,0.5273490813648293+-0.02190517908588862,0.5273490813648293+-0.02190517908588862
1,5,0.6144028630279484+-0.020643123513162512,0.2694625547551436+-0.06437234637143585,0.3506820000000001+-0.06406164434355396,0.8701319999999997+-0.012815450674868986,0.7835630000000005+-0.02134597224302515,2.4+-2.2934689882359427,0.3506824146981627+-0.06406231659845604,0.3506824146981627+-0.06406231659845604,0.35068241469816264+-0.06406231659845603
1,10,0.6177359759163966+-0.01968174660762131,0.1670023334941839+-0.05551446527942811,0.18021599999999988+-0.10304090810935239,0.9180169999999991+-0.010309859892355474,0.850944+-0.01873375733802485,1.73+-2.0042704408337713,0.18020997375328063+-0.1030438439180181,0.18020997375328063+-0.1030438439180181,0.18020997375328063+-0.10304384391801809
2,1,0.5993376320085975+-0.01985356463872533,0.5547601499325895+-0.02354935713986705,0.5996549999999999+-0.014972851264872704,0.5996549999999999+-0.014972851264872704,0.5996549999999999+-0.014972851264872704,1.05+-0.21794494717703397,0.5996587926509185+-0.01498342877632369,0.5996587926509185+-0.01498342877632369,0.5996587926509185+-0.01498342877632369
2,2,0.5969309938619877+-0.019488263190210887,0.38850888875575357+-0.022509995395467025,0.5004189999999998+-0.026708184869062148,0.750208+-0.01335696582312017,0.6669330000000002+-0.017806471042854054,0.88+-0.3249615361854386,0.5004199475065617+-0.02670523385458212,0.5004199475065617+-0.02670523385458212,0.5004199475065617+-0.026705233854582126
2,5,0.5940216793766921+-0.016324511591024094,0.20471755343582937+-0.01471271470880742,0.19194499999999992+-0.08201144356124941,0.8359440000000004+-0.016382846639091755,0.728618+-0.02687250408875214,0.74+-0.4386342439892259,0.18938362369401912+-0.08059521589082436,0.1919422572178478+-0.0820138831560355,0.19060780588916132+-0.08119777005459748
2,10,0.5971556237556922+-0.014759525910560624,0.11669690499807668+-0.009469006670052215,0.02285399999999999+-0.017093878553447135,0.9022519999999998+-0.001697850405660045,0.8223279999999993+-0.0030891124939049965,0.29+-0.45376205218153703,0.022852770711899627+-0.01709366310938665,0.022860892388451418+-0.017103901652943704,0.022856812611217307+-0.017098735165455164
5,1,0.6745069267916314+-0.010798732615492525,0.6192210819376948+-0.016218817967040854,0.6753010000000002+-0.015465707193659138,0.6753010000000002+-0.015465707193659138,0.6753010000000002+-0.015465707193659138,1.02+-0.13999999999999987,0.6753018372703413+-0.015457936117305214,0.6753018372703413+-0.015457936117305214,0.6753018372703413+-0.015457936117305214
5,2,0.6754435764427085+-0.009243884479907014,0.4614711022431336+-0.013652597060224413,0.5681400000000003+-0.010357586591479694,0.7840749999999995+-0.005178520541621905,0.7120820000000003+-0.006903613256838772,1.0+-0.0,0.5681364829396326+-0.010356118591256203,0.5681364829396326+-0.010356118591256203,0.5681364829396326+-0.010356118591256203
5,5,0.674936546317537+-0.006275937875305849,0.2624580192098278+-0.006879979981139013,0.3477499999999999+-0.01719315852308702,0.8695469999999998+-0.003440449243921496,0.78259+-0.00573341957299481,1.0+-0.0,0.34774278215223103+-0.01718907637088882,0.34774278215223103+-0.01718907637088882,0.34774278215223103+-0.017189076370888817
5,10,0.6757665901998469+-0.00335673092192169,0.15493161773146613+-0.0028202715839703084,0.059781999999999974+-0.025259637289557414,0.9059749999999999+-0.002535759255134444,0.8290510000000001+-0.004589106557926052,0.7+-0.4582575694955838,0.05979002624671912+-0.025261236633628063,0.05979002624671912+-0.025261236633628063,0.05979002624671912+-0.025261236633628063
10,1,0.5588304021052487+-0.03226317326215112,0.49977669857180046+-0.020866921943232798,0.5541019999999999+-0.04210576915340699,0.5541019999999999+-0.04210576915340699,0.5541019999999999+-0.04210576915340699,1.35+-0.4769696007084726,0.5540944881889761+-0.04210761572041523,0.5540944881889761+-0.04210761572041523,0.5540944881889761+-0.04210761572041523
10,2,0.5567749946611003+-0.03374624869410767,0.33737715660891127+-0.01888212992059675,0.329268+-0.058072290948437716,0.6646270000000001+-0.029035191940126732,0.5528609999999999+-0.03870975818834316,1.16+-0.36660605559646736,0.329265091863517+-0.058069197894094315,0.329265091863517+-0.058069197894094315,0.329265091863517+-0.058069197894094315
10,5,0.5554094832634108+-0.034560535313144514,0.17184780471507766+-0.012371075571753452,0.104156+-0.014002559194661517,0.8208279999999997+-0.0027922062961034996,0.7013859999999998+-0.0046669908934987145,1.0+-0.0,0.1041469816272966+-0.014011323728841086,0.1041469816272966+-0.014011323728841086,0.1041469816272966+-0.014011323728841086
10,10,0.5548342805574502+-0.03433533589151028,0.0947685668243906+-0.007544039329290581,0.019303999999999998+-0.005780033217897627,0.9019320000000001+-0.0005789438660181095,0.8216879999999999+-0.001062570468251399,0.98+-0.1399999999999999,0.019317585301837265+-0.0057804786561783655,0.019317585301837265+-0.0057804786561783655,0.019317585301837265+-0.0057804786561783655
This diff is collapsed.
train_rate,test_rate,auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1
1,1,0.7472829478992292+-0.017303946635186855,0.7954404763667392+-0.01691380864768139,0.6709400000000003+-0.016239033222455086,0.6709400000000003+-0.016239033222455086,0.6709400000000003+-0.016239033222455086,9.8+-0.4472135954999577,0.6709448818897635+-0.016233357095786473,0.6709448818897635+-0.016233357095786473,0.6709448818897635+-0.016233357095786473
1,2,0.7438804327608659+-0.01530592379772738,0.6925858470889119+-0.021603330859805685,0.5944339999999998+-0.01950446728316362,0.7971359999999997+-0.009799933877327945,0.7295659999999995+-0.01302994412881344,9.75+-0.49749371855331,0.5943441372154059+-0.019548666086055957,0.5944356955380574+-0.019509205949340663,0.5943895587270338+-0.01952378771956089
1,5,0.7445964756373955+-0.014659717378898006,0.5442394710245924+-0.025330933421598388,0.5136540000000004+-0.023241834781273202,0.9027310000000002+-0.004652175727549412,0.8378760000000005+-0.0077506402316195894,9.26+-0.8077128202523468,0.5136482939632545+-0.023244281794644193,0.5136482939632545+-0.023244281794644193,0.5136482939632545+-0.0232442817946442
1,10,0.7462009527352385+-0.014552064540475248,0.43713094832381694+-0.02908713819826009,0.4507879999999996+-0.02601482761811041,0.9450790000000002+-0.002603086437289384,0.9001420000000008+-0.004726376624857557,8.8+-0.9273618495495706,0.45078740157480274+-0.026021011142011913,0.45078740157480274+-0.026021011142011913,0.45078740157480274+-0.026021011142011913
2,1,0.744802736272139+-0.01489964616623676,0.7970251618976454+-0.012537634437984207,0.6622299999999999+-0.013419027535555626,0.6622299999999999+-0.013419027535555626,0.6622299999999999+-0.013419027535555626,9.89+-0.343365694267788,0.6622309711286087+-0.01341767093802498,0.6622309711286087+-0.01341767093802498,0.6622309711286087+-0.01341767093802498
2,2,0.7405814061628126+-0.012519462103028177,0.6967501434567375+-0.01351288557969554,0.5849960000000003+-0.014486862462244891,0.7924929999999997+-0.007250796576928627,0.7233179999999997+-0.009669295527596624,9.95+-0.21794494717703367,0.5849868766404202+-0.014496555080056336,0.5849868766404202+-0.014496555080056336,0.5849868766404202+-0.014496555080056336
2,5,0.7418936766762421+-0.011941126085811543,0.5610403643535282+-0.01566049485116663,0.504225+-0.012619384889922338,0.9008480000000002+-0.0025254496629313265,0.8347420000000004+-0.004212010921163436,9.76+-0.47159304490206433,0.5042257217847772+-0.012619705343563698,0.5042257217847772+-0.012619705343563698,0.5042257217847772+-0.012619705343563698
2,10,0.7431868373736745+-0.011820990808528746,0.46106418593818643+-0.02013186609463637,0.4553409999999994+-0.011273926512089739,0.9454179999999996+-0.0011899899159236592,0.9008820000000008+-0.0020070067264461194,9.38+-0.8919641248391105,0.45485783103040617+-0.011026274734129269,0.4553543307086614+-0.011267732125716665,0.4551008034860575+-0.011040318451284172
5,1,0.7289834735225027+-0.016301647032175216,0.784203589000247+-0.01261862596871064,0.6412529999999995+-0.01538260351826049,0.6412529999999995+-0.01538260351826049,0.6412529999999995+-0.01538260351826049,9.99+-0.09949874371066197,0.641259842519685+-0.015375649779913228,0.641259842519685+-0.015375649779913228,0.641259842519685+-0.015375649779913228
5,2,0.7244984534413512+-0.01301244556398768,0.6795859135845076+-0.014594208792138434,0.5539970000000004+-0.015338568740270396,0.777001+-0.0076729850123664406,0.7026410000000003+-0.010228461223468559,9.99+-0.09949874371066199,0.553989501312336+-0.01534144870922805,0.553989501312336+-0.01534144870922805,0.553989501312336+-0.01534144870922805
5,5,0.7258370981186405+-0.012263082772431827,0.5443301033240082+-0.016945998147624982,0.48382999999999954+-0.015691504070674683,0.896765+-0.00314399029896723,0.8279340000000002+-0.005226857947180126,9.81+-0.4403407771260798,0.4838320209973751+-0.01568763571817151,0.4838320209973751+-0.01568763571817151,0.4838320209973751+-0.01568763571817151
5,10,0.7273939453434459+-0.012355738617676283,0.44939037749101096+-0.01972706154140244,0.4372319999999994+-0.013790626381713047,0.9437180000000002+-0.001382706042512301,0.8976790000000001+-0.0025106292040044465,9.61+-0.6147357155721471,0.437244094488189+-0.0137843760998637,0.437244094488189+-0.0137843760998637,0.437244094488189+-0.013784376099863696
10,1,0.7191138115609566+-0.014191295707761821,0.7746464516847421+-0.011867878875525813,0.6333769999999995+-0.012769055211721824,0.6333769999999995+-0.012769055211721824,0.6333769999999995+-0.012769055211721824,9.99+-0.09949874371066199,0.6333858267716529+-0.012766162451088633,0.6333858267716529+-0.012766162451088633,0.6333858267716529+-0.012766162451088633
10,2,0.715078654046197+-0.012109302091690607,0.6664210030460502+-0.01541338421526991,0.5455740000000003+-0.010826233139924523,0.7727859999999995+-0.005416770624643425,0.6970290000000005+-0.00722231673356964,10.0+-0.0,0.5455643044619422+-0.010828290675272918,0.5455643044619422+-0.010828290675272918,0.5455643044619422+-0.010828290675272918
10,5,0.7160970784163792+-0.010819394303880143,0.5290099799446518+-0.018166510093359598,0.4681549999999996+-0.01620000231481465,0.8936359999999997+-0.003245720875244815,0.8227149999999994+-0.005402487852832244,9.92+-0.3059411708155672,0.46816272965879213+-0.016199372153656126,0.46816272965879213+-0.016199372153656126,0.46816272965879213+-0.01619937215365613
10,10,0.7171326768209092+-0.01051557284209483,0.4322265484581227+-0.020685225182008764,0.41623600000000016+-0.016823177583322364,0.9416270000000008+-0.0016850433228852014,0.8938639999999987+-0.003062728195579873,9.81+-0.48363209157374976,0.41624671916010536+-0.016831065520920934,0.41624671916010536+-0.016831065520920934,0.41624671916010536+-0.016831065520920934
...@@ -28,6 +28,17 @@ if __name__ == "__main__": ...@@ -28,6 +28,17 @@ if __name__ == "__main__":
train_data = pos_train_features + neg_train_features train_data = pos_train_features + neg_train_features
train_lbl = [1] * len(pos_train_features) + [0] * len(neg_train_features) train_lbl = [1] * len(pos_train_features) + [0] * len(neg_train_features)
if args.data_dir.find('sar') > 0:
top_protein_save_path = args.data_dir + 'denovo_sar_top_predicted.csv'
tmp_idx = 0
while os.path.exists(top_protein_save_path):
top_protein_save_path = top_protein_save_path.replace(str(tmp_idx) + '.csv', '.csv')
tmp_idx = tmp_idx + 1
top_protein_save_path = top_protein_save_path.replace('.csv', str(tmp_idx) + '.csv')
sar_top_writer = open(top_protein_save_path, 'w')
hdict = pd.read_csv(args.data_dir + 'hprots.csv', header=None).values.tolist()
hdict = {i:item[0] for i, item in enumerate(hdict)}
clf = SVC(kernel='rbf', C=10, gamma=0.001, probability=True) clf = SVC(kernel='rbf', C=10, gamma=0.001, probability=True)
print('Start training ...') print('Start training ...')
clf.fit(train_data, train_lbl) clf.fit(train_data, train_lbl)
...@@ -41,7 +52,15 @@ if __name__ == "__main__": ...@@ -41,7 +52,15 @@ if __name__ == "__main__":
test_lbl = [1] * len(pos_test_features) + [0] * len(neg_test_features) test_lbl = [1] * len(pos_test_features) + [0] * len(neg_test_features)
print('Train pairs: ', len(train_data), 'Test pairs: ', len(test_data)) print('Train pairs: ', len(train_data), 'Test pairs: ', len(test_data))
preds = clf.predict_proba(test_data)[:,1] preds = clf.predict_proba(test_data)[:,1]
if args.data_dir.find('sar') > 0:
test_pairs = get_test_pairs(args.pos_test_path, args.neg_test_path)
sorted_scores = sorted(preds.tolist(), reverse=True)
for ik in range(10):
rank = ik+1
real_idx = preds.tolist().index(sorted_scores[ik])
hprot = test_pairs[real_idx][1]
sar_top_writer.write(hdict[hprot] + ',' + str(rank) + '\n')
sar_top_writer.close()
print('Finish testing ...') print('Finish testing ...')
auc_score, aupr_score,sn, sp, acc, topk, precision, recall, f1 = get_score(test_lbl, preds) auc_score, aupr_score,sn, sp, acc, topk, precision, recall, f1 = get_score(test_lbl, preds)
......
...@@ -52,6 +52,16 @@ def main(): ...@@ -52,6 +52,16 @@ def main():
h_features = feature_ecode('doc2vec',sampleppi_dir,['all'],hseq_dict,ecode,0,5000,doc2vec_model_dir,False,modelname,ecodename) h_features = feature_ecode('doc2vec',sampleppi_dir,['all'],hseq_dict,ecode,0,5000,doc2vec_model_dir,False,modelname,ecodename)
train_data, train_lbl= load_data(args.pos_train_path, args.neg_train_path, v_features, h_features) train_data, train_lbl= load_data(args.pos_train_path, args.neg_train_path, v_features, h_features)
if args.data_dir.find('sar') > 0:
top_protein_save_path = args.data_dir + 'doc2vec_sar_top_predicted.csv'
tmp_idx = 0
while os.path.exists(top_protein_save_path):
top_protein_save_path = top_protein_save_path.replace(str(tmp_idx) + '.csv', '.csv')
tmp_idx = tmp_idx + 1
top_protein_save_path = top_protein_save_path.replace('.csv', str(tmp_idx) + '.csv')
sar_top_writer = open(top_protein_save_path, 'w')
hdict = pd.read_csv(args.data_dir + 'hprots.csv', header=None).values.tolist()
hdict = {i:item[0] for i, item in enumerate(hdict)}
for irun in range(10): for irun in range(10):
clf = RandomForestClassifier(n_estimators=1500, criterion='entropy') clf = RandomForestClassifier(n_estimators=1500, criterion='entropy')
clf.fit(train_data, train_lbl) clf.fit(train_data, train_lbl)
...@@ -62,6 +72,15 @@ def main(): ...@@ -62,6 +72,15 @@ def main():
print('Finish testing ...') print('Finish testing ...')
if args.data_dir.find('sar') > 0:
test_pairs = get_test_pairs(args.pos_test_path, args.neg_test_path)
sorted_scores = sorted(preds.tolist(), reverse=True)
for ik in range(10):
rank = ik+1
real_idx = preds.tolist().index(sorted_scores[ik])
hprot = test_pairs[real_idx][1]
sar_top_writer.write(hdict[hprot] + ',' + str(rank) + '\n')
auc_score, aupr_score,sn, sp, acc, topk, precision, recall, f1 = get_score(test_lbl, preds) auc_score, aupr_score,sn, sp, acc, topk, precision, recall, f1 = get_score(test_lbl, preds)
print('Run', (irun+1), 'performance: ') print('Run', (irun+1), 'performance: ')
print('auc_score, aupr_score,sn, sp, acc, topk, precision, recall, f1:') print('auc_score, aupr_score,sn, sp, acc, topk, precision, recall, f1:')
......
...@@ -30,6 +30,17 @@ if __name__ == "__main__": ...@@ -30,6 +30,17 @@ if __name__ == "__main__":
train_lbl = [1] * len(pos_train_features) + [0] * len(neg_train_features) train_lbl = [1] * len(pos_train_features) + [0] * len(neg_train_features)
if args.data_dir.find('sar') > 0:
top_protein_save_path = args.data_dir + 'generalized_sar_top_predicted.csv'
tmp_idx = 0
while os.path.exists(top_protein_save_path):
top_protein_save_path = top_protein_save_path.replace(str(tmp_idx) + '.csv', '.csv')
tmp_idx = tmp_idx + 1
top_protein_save_path = top_protein_save_path.replace('.csv', str(tmp_idx) + '.csv')
sar_top_writer = open(top_protein_save_path, 'w')
hdict = pd.read_csv(args.data_dir + 'hprots.csv', header=None).values.tolist()
hdict = {i:item[0] for i, item in enumerate(hdict)}
clf = SVC(kernel='rbf', C=32, gamma=0.03125, probability=True) clf = SVC(kernel='rbf', C=32, gamma=0.03125, probability=True)
print('Start training ...') print('Start training ...')
clf.fit(train_data, train_lbl) clf.fit(train_data, train_lbl)
...@@ -44,6 +55,15 @@ if __name__ == "__main__": ...@@ -44,6 +55,15 @@ if __name__ == "__main__":
print('Train pairs: ', len(train_data), 'Test pairs: ', len(test_data)) print('Train pairs: ', len(train_data), 'Test pairs: ', len(test_data))
preds = clf.predict_proba(test_data)[:,1] preds = clf.predict_proba(test_data)[:,1]
if args.data_dir.find('sar') > 0:
test_pairs = get_test_pairs(args.pos_test_path, args.neg_test_path)
sorted_scores = sorted(preds.tolist(), reverse=True)
for ik in range(10):
rank = ik+1
real_idx = preds.tolist().index(sorted_scores[ik])
hprot = test_pairs[real_idx][1]
sar_top_writer.write(hdict[hprot] + ',' + str(rank) + '\n')
sar_top_writer.close()
print('Finish testing ...') print('Finish testing ...')
auc_score, aupr_score,sn, sp, acc, topk, precision, recall, f1 = get_score(test_lbl, preds) auc_score, aupr_score,sn, sp, acc, topk, precision, recall, f1 = get_score(test_lbl, preds)
......
...@@ -30,7 +30,6 @@ def parse_file(path): ...@@ -30,7 +30,6 @@ def parse_file(path):
for seq in seqs: for seq in seqs:
avg_hidden, final_hidden, final_cell = b.get_rep(seq) avg_hidden, final_hidden, final_cell = b.get_rep(seq)
final_rep.append(final_hidden) final_rep.append(final_hidden)
break
res_df = pd.DataFrame(np.array(final_rep)) res_df = pd.DataFrame(np.array(final_rep))
res_df.to_csv(path.replace('.csv', '_1900emb.csv'), index=False, header=False) res_df.to_csv(path.replace('.csv', '_1900emb.csv'), index=False, header=False)
......
import pandas as pd
from model import * from model import *
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
import argparse import argparse
...@@ -9,7 +11,7 @@ def read_int_data(pos_path, neg_path): ...@@ -9,7 +11,7 @@ def read_int_data(pos_path, neg_path):
neg_df = pd.read_csv(neg_path).values.tolist() neg_df = pd.read_csv(neg_path).values.tolist()
int_edges = pos_df + neg_df int_edges = pos_df + neg_df
int_lbl = [1] * len(pos_df) + [0] * len(neg_df) int_lbl = [1] * len(pos_df) + [0] * len(neg_df)
return pos_df, t.LongTensor(int_edges), t.FloatTensor(int_lbl) return int_edges, t.LongTensor(int_edges), t.FloatTensor(int_lbl)
def read_train_data(pos_path, neg_path, fixval=False): def read_train_data(pos_path, neg_path, fixval=False):
pos_df = pd.read_csv(pos_path).values.tolist() pos_df = pd.read_csv(pos_path).values.tolist()
...@@ -94,7 +96,7 @@ def main(): ...@@ -94,7 +96,7 @@ def main():
hindex_tensor = t.LongTensor(list(range(n_human))) hindex_tensor = t.LongTensor(list(range(n_human)))
pos_train_pairs, val_data, train_tensor, train_lbl_tensor, val_tensor, val_lbl_tensor = read_train_data(args.pos_train_path, args.neg_train_path, args.fixval) pos_train_pairs, val_data, train_tensor, train_lbl_tensor, val_tensor, val_lbl_tensor = read_train_data(args.pos_train_path, args.neg_train_path, args.fixval)
_, test_tensor, test_lbl_tensor = read_int_data(args.pos_test_path, args.neg_test_path) test_pairs, test_tensor, test_lbl_tensor = read_int_data(args.pos_test_path, args.neg_test_path)
test_lbl = test_lbl_tensor.detach().numpy() test_lbl = test_lbl_tensor.detach().numpy()
val_lbl = val_lbl_tensor.detach().numpy() val_lbl = val_lbl_tensor.detach().numpy()
...@@ -130,18 +132,30 @@ def main(): ...@@ -130,18 +132,30 @@ def main():
model_prefix = args.data_dir + negname + '_' model_prefix = args.data_dir + negname + '_'
performance_dict = dict() performance_dict = dict()
val_performance_dict = dict() val_performance_dict = dict()
for epochs in grid_epochs:
args.epochs = epochs for lr in lrs:
for weight in ppi_weights: for weight in ppi_weights:
args.ppi_weight = weight args.ppi_weight = weight
for hid in hiddens: for hid in hiddens:
for lr in lrs: for epochs in grid_epochs:
args.epochs = epochs
all_scores = list() all_scores = list()
params = [epochs, weight, hid, lr, 0] params = [epochs, weight, hid, lr, 0]
params = [str(item) for item in params] params = [str(item) for item in params]
save_model_prefix = model_prefix + '_'.join(params) + '.model' save_model_prefix = model_prefix + '_'.join(params) + '.model'
val_all_scores = list() val_all_scores = list()
if args.data_dir.find('sar') > 0:
top_protein_save_path = args.data_dir + '_'.join(params) + '_mtt_sar_top_predicted.csv'
tmp_idx = 0
while os.path.exists(top_protein_save_path):
top_protein_save_path = top_protein_save_path.replace(str(tmp_idx) + '.csv', '.csv')
tmp_idx = tmp_idx + 1
top_protein_save_path = top_protein_save_path.replace('.csv', str(tmp_idx) + '.csv')
sar_top_writer = open(top_protein_save_path, 'w')
hdict = pd.read_csv(args.data_dir + 'hprots.csv', header=None).values.tolist()
hdict = {i:item[0] for i, item in enumerate(hdict)}
for irun in range(args.n_runs): for irun in range(args.n_runs):
save_model_path = save_model_prefix.replace('0.model', str(irun) + '.model') save_model_path = save_model_prefix.replace('0.model', str(irun) + '.model')
model = Model(n_virus, n_human, hid) model = Model(n_virus, n_human, hid)
...@@ -152,7 +166,7 @@ def main(): ...@@ -152,7 +166,7 @@ def main():
if t.cuda.is_available(): if t.cuda.is_available():
model = model.cuda() model = model.cuda()
best_ap = 0 best_ap = 0
for epoch in range(0, args.epochs): for epoch in range(0, epochs):
model.train() model.train()
optimizer.zero_grad() optimizer.zero_grad()
score, hppi_out = model(vindex_tensor, hindex_tensor, score, hppi_out = model(vindex_tensor, hindex_tensor,
...@@ -171,7 +185,7 @@ def main(): ...@@ -171,7 +185,7 @@ def main():
val_pred_lbl = pred_score.tolist() val_pred_lbl = pred_score.tolist()
val_pred_lbl = [item[0] if type(item) == list else item for item in val_pred_lbl] val_pred_lbl = [item[0] if type(item) == list else item for item in val_pred_lbl]
auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1 = get_score(val_lbl, val_pred_lbl,K=1) auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1 = get_score(val_lbl, val_pred_lbl)
print('Validation set lr:%.4f, auc:%.4f, aupr:%.4f' %(lr, auc_score, aupr_score)) print('Validation set lr:%.4f, auc:%.4f, aupr:%.4f' %(lr, auc_score, aupr_score))
if best_ap < aupr_score: if best_ap < aupr_score:
best_ap = aupr_score best_ap = aupr_score
...@@ -191,10 +205,36 @@ def main(): ...@@ -191,10 +205,36 @@ def main():
test_pred_lbl = pred_score.tolist() test_pred_lbl = pred_score.tolist()
test_pred_lbl = [item[0] if type(item) == list else item for item in test_pred_lbl] test_pred_lbl = [item[0] if type(item) == list else item for item in test_pred_lbl]
auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1 = get_score(test_lbl, test_pred_lbl) # joint_list = [[pred,target] for pred, target in zip(test_pred_lbl, test_lbl)]
print('lr:%.4f, auc:%.4f, aupr:%.4f' %(lr, auc_score, aupr_score)) # df = pd.DataFrame(np.array(joint_list), columns=['pred', 'target'])
# df.to_csv(save_model_path.replace('.model', '.scores'), index=False)
all_scores.append([auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1])#, sn, sp, acc, topk]) # only for sars, get the top highest prediction scores
if args.data_dir.find('sar') > 0:
sorted_scores = sorted(test_pred_lbl, reverse=True)
c = 0
for i in range(10):
rank = i+1
val = sorted_scores[i]
found_indexes = [j for j,item in enumerate(test_pred_lbl) if item == val]
hprots = [test_pairs[real_idx][1] for real_idx in found_indexes]
for hprot in hprots:
sar_top_writer.write(hdict[hprot] + ',' + str(rank) + '\n')
c = c+1
if c == 10:
break
if c == 10:
break
# auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1 = get_score(test_lbl, test_pred_lbl)
# print('lr:%.4f, auc:%.4f, aupr:%.4f' %(lr, auc_score, aupr_score))
#
# all_scores.append([auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1])#, sn, sp, acc, topk])
topks = list()
for k in range(1,11):
auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1 = get_score(test_lbl, test_pred_lbl,K=k)
topks.append(topk)
all_scores.append([auc_score, aupr_score, sn, sp, acc, topks[0], precision, recall, f1] + topks)
print(params, all_scores[-1])
# Save the performance on the validation set also # Save the performance on the validation set also
pred_score = model.infer(vindex_tensor, hindex_tensor, pred_score = model.infer(vindex_tensor, hindex_tensor,
...@@ -204,10 +244,12 @@ def main(): ...@@ -204,10 +244,12 @@ def main():
test_pred_lbl = pred_score.tolist() test_pred_lbl = pred_score.tolist()
test_pred_lbl = [item[0] if type(item) == list else item for item in test_pred_lbl] test_pred_lbl = [item[0] if type(item) == list else item for item in test_pred_lbl]
auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1 = get_score(val_lbl, test_pred_lbl,K=1) auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1 = get_score(val_lbl, test_pred_lbl)
val_all_scores.append([auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1])#, sn, sp, acc, topk]) val_all_scores.append([auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1])#, sn, sp, acc, topk])
if max_auc[0] < auc_score: if max_auc[0] < auc_score:
max_auc = [auc_score, aupr_score] max_auc = [auc_score, aupr_score]
if args.data_dir.find('sar') > 0:
sar_top_writer.close()
t.cuda.empty_cache() t.cuda.empty_cache()
arr = np.array(all_scores) arr = np.array(all_scores)
print('all_scores: ', all_scores) print('all_scores: ', all_scores)
......
...@@ -9,7 +9,7 @@ def read_int_data(pos_path, neg_path): ...@@ -9,7 +9,7 @@ def read_int_data(pos_path, neg_path):
neg_df = pd.read_csv(neg_path).values.tolist() neg_df = pd.read_csv(neg_path).values.tolist()
int_edges = pos_df + neg_df int_edges = pos_df + neg_df
int_lbl = [1] * len(pos_df) + [0] * len(neg_df) int_lbl = [1] * len(pos_df) + [0] * len(neg_df)
return pos_df, t.LongTensor(int_edges), t.FloatTensor(int_lbl) return int_edges, t.LongTensor(int_edges), t.FloatTensor(int_lbl)
def read_train_data(pos_path, neg_path, fixval=False): def read_train_data(pos_path, neg_path, fixval=False):
pos_df = pd.read_csv(pos_path).values.tolist() pos_df = pd.read_csv(pos_path).values.tolist()
...@@ -84,7 +84,7 @@ def main(): ...@@ -84,7 +84,7 @@ def main():
hindex_tensor = t.LongTensor(list(range(n_human))) hindex_tensor = t.LongTensor(list(range(n_human)))
pos_train_pairs, val_data, train_tensor, train_lbl_tensor, val_tensor, val_lbl_tensor = read_train_data(args.pos_train_path, args.neg_train_path, args.fixval) pos_train_pairs, val_data, train_tensor, train_lbl_tensor, val_tensor, val_lbl_tensor = read_train_data(args.pos_train_path, args.neg_train_path, args.fixval)
_, test_tensor, test_lbl_tensor = read_int_data(args.pos_test_path, args.neg_test_path) test_pairs, test_tensor, test_lbl_tensor = read_int_data(args.pos_test_path, args.neg_test_path)
test_lbl = test_lbl_tensor.detach().numpy() test_lbl = test_lbl_tensor.detach().numpy()
val_lbl = val_lbl_tensor.detach().numpy() val_lbl = val_lbl_tensor.detach().numpy()
...@@ -107,7 +107,7 @@ def main(): ...@@ -107,7 +107,7 @@ def main():
max_auc = [0,0] max_auc = [0,0]
lrs = [0.001, 0.01] lrs = [0.001, 0.01]
grid_epochs = [200] grid_epochs = [100]
hiddens = [8,16,32,64] hiddens = [8,16,32,64]
model_prefix = args.data_dir + negname + '_' model_prefix = args.data_dir + negname + '_'
performance_dict = dict() performance_dict = dict()
...@@ -121,6 +121,16 @@ def main(): ...@@ -121,6 +121,16 @@ def main():
params = [str(item) for item in params] params = [str(item) for item in params]
save_model_prefix = model_prefix + '_'.join(params) + '.model' save_model_prefix = model_prefix + '_'.join(params) + '.model'
val_all_scores = list() val_all_scores = list()
if args.data_dir.find('sar') > 0:
top_protein_save_path = args.data_dir + '_'.join(params) + '_stt_sar_top_predicted.csv'
tmp_idx = 0
while os.path.exists(top_protein_save_path):
top_protein_save_path = top_protein_save_path.replace(str(tmp_idx) + '.csv', '.csv')
tmp_idx = tmp_idx + 1
top_protein_save_path = top_protein_save_path.replace('.csv', str(tmp_idx) + '.csv')
sar_top_writer = open(top_protein_save_path, 'w')
hdict = pd.read_csv(args.data_dir + 'hprots.csv', header=None).values.tolist()
hdict = {i:item[0] for i, item in enumerate(hdict)}
for irun in range(args.n_runs): for irun in range(args.n_runs):
save_model_path = save_model_prefix.replace('0.model', str(irun) + '.model') save_model_path = save_model_prefix.replace('0.model', str(irun) + '.model')
...@@ -152,8 +162,7 @@ def main(): ...@@ -152,8 +162,7 @@ def main():
val_pred_lbl = pred_score.tolist() val_pred_lbl = pred_score.tolist()
val_pred_lbl = [item[0] if type(item) == list else item for item in val_pred_lbl] val_pred_lbl = [item[0] if type(item) == list else item for item in val_pred_lbl]
auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1 = get_score(val_lbl, auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1 = get_score(val_lbl,
val_pred_lbl, val_pred_lbl)
K=1)
print('Validation set lr:%.4f, auc:%.4f, aupr:%.4f' % (lr, auc_score, aupr_score)) print('Validation set lr:%.4f, auc:%.4f, aupr:%.4f' % (lr, auc_score, aupr_score))
if best_ap < aupr_score: if best_ap < aupr_score:
best_ap = aupr_score best_ap = aupr_score
...@@ -171,6 +180,17 @@ def main(): ...@@ -171,6 +180,17 @@ def main():
pred_score = pred_score.detach().numpy() if not t.cuda.is_available() else pred_score.cpu().detach().numpy() pred_score = pred_score.detach().numpy() if not t.cuda.is_available() else pred_score.cpu().detach().numpy()
test_pred_lbl = pred_score.tolist() test_pred_lbl = pred_score.tolist()
joint_list = [[pred,target] for pred, target in zip(test_pred_lbl, test_lbl)]
df = pd.DataFrame(np.array(joint_list), columns=['pred', 'target'])
df.to_csv(save_model_path.replace('.model', '.scores'), index=False)
# only for sars, get the top highest prediction scores
if args.data_dir.find('sar') > 0:
sorted_scores = sorted(test_pred_lbl, reverse=True)
for ik in range(10):
rank = ik+1
real_idx = test_pred_lbl.index(sorted_scores[ik])
hprot = test_pairs[real_idx][1]
sar_top_writer.write(hdict[hprot] + ',' + str(rank) + '\n')
test_pred_lbl = [item[0] if type(item) == list else item for item in test_pred_lbl] test_pred_lbl = [item[0] if type(item) == list else item for item in test_pred_lbl]
auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1 = get_score(test_lbl, test_pred_lbl) auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1 = get_score(test_lbl, test_pred_lbl)
print('lr:%.4f, auc:%.4f, aupr:%.4f' % (lr, auc_score, aupr_score)) print('lr:%.4f, auc:%.4f, aupr:%.4f' % (lr, auc_score, aupr_score))
...@@ -190,6 +210,8 @@ def main(): ...@@ -190,6 +210,8 @@ def main():
[auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1]) # , sn, sp, acc, topk]) [auc_score, aupr_score, sn, sp, acc, topk, precision, recall, f1]) # , sn, sp, acc, topk])
if max_auc[0] < auc_score: if max_auc[0] < auc_score:
max_auc = [auc_score, aupr_score] max_auc = [auc_score, aupr_score]
if args.data_dir.find('sar') > 0:
sar_top_writer.close()
t.cuda.empty_cache() t.cuda.empty_cache()
arr = np.array(all_scores) arr = np.array(all_scores)
print('all_scores: ', all_scores) print('all_scores: ', all_scores)
......
...@@ -110,3 +110,8 @@ def get_score2(targets, preds, K=10): ...@@ -110,3 +110,8 @@ def get_score2(targets, preds, K=10):
aupr_score = metrics.average_precision_score(targets, preds, average='micro') aupr_score = metrics.average_precision_score(targets, preds, average='micro')
return auc_score, aupr_score return auc_score, aupr_score
def get_test_pairs(pos_path, neg_path):
    """Load the test interaction pairs from two CSV files.

    Args:
        pos_path: path to the CSV listing positive (interacting) pairs.
        neg_path: path to the CSV listing negative (non-interacting) pairs.

    Returns:
        list: all pairs as row lists, positives first, then negatives
        (same ordering callers rely on to align with the label vector).
    """
    pairs = []
    for csv_path in (pos_path, neg_path):
        pairs.extend(pd.read_csv(csv_path).values.tolist())
    return pairs
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment